#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

import pytest
import os.path
import json

from shapely.geometry import Point
from shapely.geometry import LineString
from shapely.geometry.base import BaseGeometry
from shapely.wkt import loads as wkt_loads
import geopandas

from tests.test_base import TestBase
from tests import geoparquet_input_location
from tests import plain_parquet_input_location
from tests import legacy_parquet_input_location


class TestGeoParquet(TestBase):
    def test_interoperability_with_geopandas(self, tmp_path):
        test_data = [
            [1, Point(0, 0), LineString([(1, 2), (3, 4), (5, 6)])],
            [2, LineString([(1, 2), (3, 4), (5, 6)]), Point(1, 1)],
            [3, Point(1, 1), LineString([(1, 2), (3, 4), (5, 6)])]
        ]
        df = self.spark.createDataFrame(data=test_data, schema=["id", "g0", "g1"]).repartition(1)
        geoparquet_save_path = os.path.join(tmp_path, "test.parquet")
        df.write.format("geoparquet").save(geoparquet_save_path)

        # Load geoparquet file written by sedona using geopandas
        gdf = geopandas.read_parquet(geoparquet_save_path)
        assert gdf.dtypes['g0'].name == 'geometry'
        assert gdf.dtypes['g1'].name == 'geometry'

        # Load geoparquet file written by geopandas using sedona
        gdf = geopandas.GeoDataFrame([
            {'g': wkt_loads('POINT (1 2)'), 'i': 10},
            {'g': wkt_loads('LINESTRING (1 2, 3 4)'), 'i': 20}
        ], geometry='g')
        geoparquet_save_path2 = os.path.join(tmp_path, "test_2.parquet")
        gdf.to_parquet(geoparquet_save_path2)
        df2 = self.spark.read.format("geoparquet").load(geoparquet_save_path2)
        assert df2.count() == 2
        row = df2.collect()[0]
        assert isinstance(row['g'], BaseGeometry)

    def test_load_geoparquet_with_spatial_filter(self):
        df = self.spark.read.format("geoparquet").load(geoparquet_input_location)\
            .where("ST_Contains(geometry, ST_GeomFromText('POINT (35.174722 -6.552465)'))")
        rows = df.collect()
        assert len(rows) == 1
        assert rows[0]['name'] == 'Tanzania'

    def test_load_plain_parquet_file(self):
        with pytest.raises(Exception) as excinfo:
            self.spark.read.format("geoparquet").load(plain_parquet_input_location)
        assert "does not contain valid geo metadata" in str(excinfo.value)

    def test_inspect_geoparquet_metadata(self):
        df = self.spark.read.format("geoparquet.metadata").load(geoparquet_input_location)
        rows = df.collect()
        assert len(rows) == 1
        row = rows[0]
        assert row['path'].endswith('.parquet')
        assert len(row['version'].split('.')) == 3
        assert row['primary_column'] == 'geometry'
        column_metadata = row['columns']['geometry']
        assert column_metadata['encoding'] == 'WKB'
        assert len(column_metadata['bbox']) == 4
        assert isinstance(json.loads(column_metadata['crs']), dict)

    def test_reading_legacy_parquet_files(self):
        df = self.spark.read.format("geoparquet").option("legacyMode", "true").load(legacy_parquet_input_location)
        rows = df.collect()
        assert len(rows) > 0
        for row in rows:
            assert isinstance(row['geom'], BaseGeometry)
            assert isinstance(row['struct_geom']['g0'], BaseGeometry)
            assert isinstance(row['struct_geom']['g1'], BaseGeometry)
